{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use Q Learning to Play Taxi-v3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import sys\n",
    "import logging\n",
    "import itertools\n",
    "\n",
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "import gym\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "logging.basicConfig(level=logging.INFO,\n",
    "        format='%(asctime)s [%(levelname)s] %(message)s',\n",
    "        stream=sys.stdout, datefmt='%H:%M:%S')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "00:00:00 [INFO] env: <TaxiEnv<Taxi-v3>>\n",
      "00:00:00 [INFO] action_space: Discrete(6)\n",
      "00:00:00 [INFO] observation_space: Discrete(500)\n",
      "00:00:00 [INFO] reward_range: (-inf, inf)\n",
      "00:00:00 [INFO] metadata: {'render.modes': ['human', 'ansi']}\n",
      "00:00:00 [INFO] _max_episode_steps: 200\n",
      "00:00:00 [INFO] _elapsed_steps: None\n",
      "00:00:00 [INFO] id: Taxi-v3\n",
      "00:00:00 [INFO] entry_point: gym.envs.toy_text:TaxiEnv\n",
      "00:00:00 [INFO] reward_threshold: 8\n",
      "00:00:00 [INFO] nondeterministic: False\n",
      "00:00:00 [INFO] max_episode_steps: 200\n",
      "00:00:00 [INFO] _kwargs: {}\n",
      "00:00:00 [INFO] _env_name: Taxi\n"
     ]
    }
   ],
   "source": [
    "# Build the Taxi-v3 environment and fix its seed for reproducible runs.\n",
    "env = gym.make('Taxi-v3')\n",
    "env.seed(0)\n",
    "\n",
    "# Log every attribute of the environment object, then of its spec.\n",
    "for key, value in vars(env).items():\n",
    "    logging.info('%s: %s', key, value)\n",
    "for key, value in vars(env.spec).items():\n",
    "    logging.info('%s: %s', key, value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class QLearningAgent:\n",
    "    def __init__(self, env):\n",
    "        self.gamma = 0.99\n",
    "        self.learning_rate = 0.2\n",
    "        self.epsilon = 0.01\n",
    "        self.action_n = env.action_space.n\n",
    "        self.q = np.zeros((env.observation_space.n, env.action_space.n))\n",
    "\n",
    "    def reset(self, mode=None):\n",
    "        self.mode = mode\n",
    "        if self.mode == 'train':\n",
    "            self.trajectory = []\n",
    "\n",
    "    def step(self, observation, reward, done):\n",
    "        if self.mode == 'train' and np.random.uniform() < self.epsilon:\n",
    "            action = np.random.randint(self.action_n)\n",
    "        else:\n",
    "            action = self.q[observation].argmax()\n",
    "        if self.mode == 'train':\n",
    "            self.trajectory += [observation, reward, done, action]\n",
    "            if len(self.trajectory) >= 8:\n",
    "                self.learn()\n",
    "        return action\n",
    "\n",
    "    def close(self):\n",
    "        pass\n",
    "\n",
    "    def learn(self):\n",
    "        state, _, _, action, next_state, reward, done, _ = \\\n",
    "                        self.trajectory[-8:]\n",
    "\n",
    "        v = reward + self.gamma * self.q[next_state].max() * (1. - done)\n",
    "        target = reward + self.gamma * v * (1. - done)\n",
    "        td_error = target - self.q[state, action]\n",
    "        self.q[state, action] += self.learning_rate * td_error\n",
    "\n",
    "\n",
    "agent = QLearningAgent(env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "00:00:00 [INFO] ==== train ====\n",
      "00:00:06 [INFO] ==== test ====\n",
      "00:00:06 [INFO] average episode reward = 7.84 ± 2.44\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD4CAYAAAAEhuazAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAciElEQVR4nO3deXxV9Z3/8dfHAGHfV4EYQBYRFSFEcAEBkQg+xNqZipahHXXouLRurZVa+7OdonQZa/21yjCOrf7a+fmjjlWrgoJaHRdGgyJbAQNEE7aELWQh+/f3xzkJN8kNCblJ7nLez8cjD879nnPu+dzLve977vd8z7nmnENERILljGgXICIi7U/hLyISQAp/EZEAUviLiASQwl9EJIA6RLuA5urfv79LTU2NdhkiInFlw4YNh5xzA+q3x034p6amkpmZGe0yRETiipl9Ea5d3T4iIgGk8BcRCSCFv4hIACn8RUQCSOEvIhJACn8RkQBS+IuIBFBgw985x6rMHMorq5tctqrakZVXxK78IgCq/duh8zfmHOPdnfm8tf1gnfssr6zmi8PFOOd44ZNcjpdWsDu/qME2AD4/WMj63YcBKC6rZO+xE2GXO1ZSzl8+21f7OLLyCpv3oIGjxeXkHi3hy8MlYec759hzqJiKqmr/vov42/7jZGYfoaS8kpwjJfzmrc/56Svb+Nma7by2eT9Histr16+sqiYrr6j2+fnt21l8sOsQXx4uobSiisqqanYeLGRzbgGPrdvJnkPFrHhnF18cLuat7QfrPObcoyUUl1VSXe345es7WLftIAA5R0p44q9ZrNt2kPzCMpav3s6WvQUcL63gQEEpu/O9mht7jmus2bKfj7OPUFpRBcD+ghMUllbUzi+rrCL7UDFHiss5VFQGwFvbD7IvpMaS8kp+tmY7Ow8W8uKne/n+85uoqvYuk37weCm5R0tY8c4uisoqa5+fVR/n1C6z82Ah7+7M58vDJfwpM4eyyqqwdb+x9QCrPs7hRPnJ+ZnZR+o8zjVb9pNfWMbbO/LC/v8eKCjl+Q255BWW1mk/XFRGfmFZnbasvEKcc7z5t4O89/khNnxxhC17C8jKK+LVTftrn4+t+wr45MujDZ47gN35RbyfdahBHa9u2s+qzBwOFJRSXe14P+sQBSUVbD9wnJc27iXnSN3ai8oq2XfsBK9t3s/hojKOFpfz6qb9tfPf2n6Qp9/bU3u74EQFL23c602XVPD8hlxWZebUeZ3mHi2hpLySt3fkkXu0hIKSCtbvPkxllffedc7xp8yc2tdGqEN+DTVqXts1qv3/29KKKjbmHOPVTfs5VuItf7S4vPa5rqyqZv3uwxSUVJCZfYTtB7z/y5wjJWzOLeBwURmrN598nK3NonU9fzPLAH4NJAFPOeeWn2r5tLQ011oneR0oKGXFO7v4/QfZfGf2aO6YeTYvfJLL4eJyfvH6jtrl/t+SqVy/cn2rbFMEYMyg7uw8GP7DX6Qxux6eR9IZ1qJ1zWyDcy6tfntUzvA1syTgt8AcIBf42Mxeds5ta+ttF5RUMPWRN2tvHy4qY+HKD/nky2MNllXwS2tT8EtLlFVW0bVT68Z1tLp90oEs59xu51w58BywoD02fKSkvM7tw0XlYYNfRCRWtHbwQ/TCfyiQE3I7129rd2u2HojGZqWeb886u1Xvb8yg7o3OGzuoR4O2TknNeys8/JXzmJTSu6VlxY2hvbtEu4SoGNan7R53uNfNZaP7105/a/rI2umZY09eh+0nC85tk3qiFf7hOq8aHHwwsyVmlmlmmfn5+e1QVuz66qRhzVru6vOH8Ppd07nuwqGsvXs6nzw4hx0/zWiw3Ot3Ta9z+5ZLRzCkV+fa29v/xVtnRP9uzD9vSNhtLZqawr8sOJfs5fPJXj6/zrw37p7Oz//ufG68KAWAcYN71C5X/03ww/nncO+VYxvcz/qls+u8IZ65KZ2bLx0BwP++4UIApqT2
oV+3TrXLDO7pPYY1d3qPvX7Q/3D+Obx+93Syl89n98PzAOjVpSM7l13Fnkfmsctvq7Fi0ST++t3LAXjx9ku48aIU5owfXDu/5s372f+6kl0PzyOlb9faeWMH9eDaiWeGfe6AOs/rxh/NaVAnwMAeybVtNc9lqH+YelbY+37htosb3S5Al45JfG/uWDY9dGWd9h/MG8eKRZN4//5ZTB/jBdCGH15B9vL5zB43EIBND13Jw185jz2PzCN7+Xy+M3t0o9t593szefLrk8hePp/nlkzl3jljaufddvkoAC5M6c23po9s8Bqqf78XDO8NwPczxgFw0yUjal8zH9w/i18vnFjn/uv7u8nD+NHV4wG4Pm04y74ygezl83nz3hmA9/p67/uzeOamdFbfeRnrl87msesnkr18Pk8t9rrMr7vw5D7qi7df0ui2aj5Efr1wIu/fP4vs5fN54bZLyF4+n1/+/QUArLnrMv7PzRfVrrN03jn89sZJPHj1eH73j+lk/vAK3r9/FounpTa6nUhE5YCvmU0DHnLOzfVvLwVwzj3S2DqtdcB3z6FiZv7yrxHfT3v63TenMG1UP7btP851T3xQ237txDN5ceO+2tsrFk0iY0L4oE69/9U6t2veaKn3v8riaWfxkwUT6iyXvXw+H2cfYUT/bnTtlERm9lHOGdKTKcvWMWPMAH7x9+czsEfnsNuo/yYOJyuviLe35/FPIeFe46n/3s2nXx7jt1+fdMr7yCssZWCPzpwor6KyupoenTtypLicnQcLmTqyX+1y9676jPQRfbh+SsPwPFZSTpdOSSR3SKpte/bDbB5du5P75o7jhvThmNXdVymtqOKxdZ9z5+zRdOmUVP8u6ygpr+TxN7NYNDWFS3/2NuCF2Iu3XUxVteNf1+7kW9NH0rtrp9rnb8dPM+rUs+dQMaUVVZwzpGftdMekMxjetwvJHZLIKywl50hJ7TrlVdVMSulD9qFiLvdf62/cPZ1nPsimb7dOXHPBmYz2PxSrqx1fXfEBOUdO8PWLUrg7JDyPl1awKaeAS0P2TsOprnb869odTBvZn8G9OnOoqIwJQ3tRVe3o1aVjg+XPeXANXTsl8eHS2Ty6did3zDqb7slet0bu0RLmPPouJyqqyF4+n73HTpB9qBiA84b1YnNuAVNS+/Lo2p3cPnMUPTp3bFDLvoITdE/uwN5jJ/j3d3fzvYxx/GH9F9x2ecPlWyL3aAk7DxYya9wgxj24mtKKar4zezT3zBnD7vwihvTqQlFZJdmHi5mS2rfJ+3tt8366JXdgxpgGV11uFY0d8I1W+HcAdgKzgb3Ax8CNzrmtja3TWuGflVfEFY++E/H9tIVeXTpScKKiQXtomD793h5+8so2vnlxKg9d430dXPJsJm9sO8gfbr6o0TdqaPg//c00Zo0bBHhD2oDagDudAK/vf3YfpnPHpNo9NKmrqed2x4FC9hec4PKxA1ttmyve2cWjb+xk57KrWu0+I1X/NXe682NJZVU1//nRl9yQnkLHZnYdtreYGu3jnKs0szuA1/GGej59quCPVHW1Y/ehYkYN6MbClR+21WYi1jHpDLKXz+eeVRt54ZO9YZdJH+HtScwadzIgHrnuPMYO7sHFo/qFXae+4X1Odk2Ee4Nd1sSeXmMuGtm87QfVunumn3L+2ME9GDu44fGISPzzjFH884xRrXqfkWoq1OMh9Gt0SDqjzbpl2lrUfszFOfca8Fp7bOvf/3s3j6zezuJpZ3GoqLzpFdrJWf268kWYk3F+smACaWf15Qd/3txg3oShvdjzyLw6b5B+3ZO598qxzd7uqb7rtWSPX5rn7IGtG+wikYibX/KKxKf+UM41W2JrZE9j+zfdkztw40UpjB7UPeyZmvG0ZyQisSk2O6kC5ob0hgciAaak9uWrk5s3yud0tPBEQRFJIAr/GHCq4YCt7d45Yxg1oPEx8CISDAr/KFq5OI0b0lM4q183ANqjN+fbs0er20hEghH+NVkX
nUvYNW7MoB48ct15td0wUbrGnogEUCAO+NaI2XBthx3xV759KcfDnEMgIsEUqPCvuQZ5zGmHD6UJQ3u1/UZEJG4Eotsnmv7tHyY3e1l1xYtIewlE+EczVOeeO7jJZTr4p4XXXJRMRKStBarbJ1b17daJx66fyMVn6/IIItI+AhH+r22OrTN7w7n2wqj8nIGIBFQgun1ERKQuhX+UbP3x3GiXICIBpvBvJf27d2p6oRDdkgPR4yYiMUrh347O7KXRPCISGxT+LXTeKU6aamzI5tP/OAXwruMvIhJNCv8Wmj6m8V+7+taMhr9LC2DtcR0HEZFmUPhHQcxeY0hEAkPh3450+QYRiRUK/9PQr9vJET3qwhGReKbwPw2R/giKPi5EJFYo/FtBz87NG7Ovrn4RiRUK/2aaktqH+vH9wLxzAFgwse51ebSHLyKxTuHfTEN7d2nQltyx4dN32ej+fG3K8LD3oQ8FEYkVCv9W9ujXJtK1U/huoP7dkwFYNDWlPUsSEWlAF5hpJu9gb2T77t2SO5C9fH7rFCQiEgHt+Z8WHbIVkcSg8BcRCSCFfzO5etdk0Nm6IhLPFP4tNOmsPqe9jj4wRCRWKPxbaObYgdEuQUSkxRT+zRTuUO/AHt51+4f1aXgOgIhILFP4n4b6l2Kee+4gfvfNKdxyWfjr94uIxCqN84+AmTFznLp/RCT+aM+/CQN7eGflhu71d0pq2dOm470iEisU/qehZrTO6rsui24hIiIRiij8zewXZrbdzDaZ2Z/NrHfIvKVmlmVmO8xsbkj7ZDPb7M973CK9SH47cZzc++/ZuWOTyy+ZPpI7Zp7dtkWJiLRQpHv+a4EJzrnzgZ3AUgAzGw8sBM4FMoAnzCzJX+dJYAkw2v/LiLCGdtecj6sfzDuH784d2/bFiIi0QETh75x7wzlX6d9cDwzzpxcAzznnypxze4AsIN3MhgA9nXMfOu+U2WeBayOpIRr0A+wiEu9as8//JmC1Pz0UyAmZl+u3DfWn67eHZWZLzCzTzDLz8/NbsdTTV//yDi0RJz1cIhIATQ71NLN1wOAwsx5wzr3kL/MAUAn8sWa1MMu7U7SH5ZxbCawESEtLi5n9bWW4iMS7JsPfOXfFqeab2TeAq4HZ7uTucS4Q+nNWw4B9fvuwMO0xK2Y+cUREWlGko30ygO8D1zjnSkJmvQwsNLNkMxuBd2D3I+fcfqDQzKb6o3wWAy9FUkN7afpDwPzl9HEhIrEv0j7/3wA9gLVmttHMVgA457YCq4BtwBrgdudclb/OrcBTeAeBd3HyOEFM+uok/4tKE5n+x1su4uZLRzDA/6lGEZFYFtHlHZxzjQ5kd84tA5aFac8EJkSy3faSvXw+f/mseb1SYwf34MGrx7dxRSIirUNn+LYjHScWkVih8G9Czcgep958EUkgCv8mWJj9de3Bi0i8S/jwP1ZS3qL1nr0pHYDzhvYC4JoLTp6Lpm8AIhLvEj78f/NWVovWmzaqHwAp/bqSvXw+GRMGa49fRBJGwod/a4p0j19nBotIrAhk+H9nVmSXWlaGi0i8C2T433OlLrUsIsEWuPD/1fUXtHjday44E4CunfTTxyIS3wIX/l+5cFjTCzXiwavHs+mhK+nSKanphUPMHDugxdsUEWkL2oU9DUlnWLN+wrG+JxdNJu94ma7nLyIxI3B7/s3VmjHduWMSKf26tuI9iohERuEvIhJACn8RkQBS+IuIBJDCX0QkgBT+IiIBpPAXEQkghb+ISAAp/BuhE7JEJJEp/EVEAijhw1878CIiDSV8+Dv95qKISAMJH/4iItKQwl9EJIAU/iIiAaTwFxEJIIV/IzRISEQSmcJfRCSAEj78Nc5fRKShhA9/ERFpKOHDXyd5iYg0lPDhLyIiDSn8RUQCSOEvIhJArRL+ZvZdM3Nm1j+kbamZZZnZDjObG9I+2cw2+/Metxi9cH5sViUi0joiDn8zGw7MAb4MaRsPLATOBTKAJ8wsyZ/9JLAE
GO3/ZURag4iInJ7W2PP/FXAfEDquZgHwnHOuzDm3B8gC0s1sCNDTOfehc84BzwLXtkINIiJyGiIKfzO7BtjrnPus3qyhQE7I7Vy/bag/Xb+9sftfYmaZZpaZn5/fwhpbtJqISELr0NQCZrYOGBxm1gPAD4Arw60Wps2doj0s59xKYCVAWlqaRuyLiLSSJsPfOXdFuHYzOw8YAXzmH7MdBnxiZul4e/TDQxYfBuzz24eFaRcRkXbU4m4f59xm59xA51yqcy4VL9gnOecOAC8DC80s2cxG4B3Y/cg5tx8oNLOp/iifxcBLkT+MU9XZlvcuIhKfmtzzbwnn3FYzWwVsAyqB251zVf7sW4HfA12A1f6fiIi0o1YLf3/vP/T2MmBZmOUygQmttd22EqOnH4iItAqd4SsiEkAKfxGRAEr48FfvjYhIQwkf/iIi0pDCX0QkgBT+IiIBlPDhr5O8REQaSvjwFxGRhhT+IiIBFKjw79wxUA9XRKRRCZ+GoeP8LewVpUVEgifhwz+Ua/ynA0REAiVY4a/sFxEBAhb+IiLiCVT46zo/IiKehA9/dfWIiDSU+OEf7QJERGJQwof/xpxjtdMa6iki4kn48N/wxdFolyAiEnMSPvxFRKQhhb+ISAAp/EVEAkjhLyISQIEKf13bR0TEE6jwD7Vi0eSw7fdljOUvd1zaztWIiLSvDtEuoD01Z5z/bZef3Q6ViIhEV6D2/EO7fQb2TI5iJSIi0RWs8A/p8p+U0oerJgyOXjEiIlEUqPCv71fXT4x2CSIiURGo8K9/SefOHZOiU4iISJQFKvxFRMQTqNE+p/LWvTPonqynQ0SCQWnnGzmge7RLEBFpN+r2EREJIIW/iEgAKfxFRAIo4vA3s2+b2Q4z22pmPw9pX2pmWf68uSHtk81ssz/vcbP6AzBFRKStRXTA18xmAguA851zZWY20G8fDywEzgXOBNaZ2RjnXBXwJLAEWA+8BmQAqyOpo7mcLuopIgJEPtrnVmC5c64MwDmX57cvAJ7z2/eYWRaQbmbZQE/n3IcAZvYscC3tFP7hPHNTOgcLSqO1eRGRqIg0/McAl5nZMqAU+K5z7mNgKN6efY1cv63Cn67fHpaZLcH7lkBKSkqEpYY3Y8yANrlfEZFY1mT4m9k6INwV0B7w1+8DTAWmAKvMbCSEvXayO0V7WM65lcBKgLS0tIg7bXR0QUTE02T4O+euaGyemd0KvOCcc8BHZlYN9Mfbox8esugwYJ/fPixMu4iItKNIR/u8CMwCMLMxQCfgEPAysNDMks1sBDAa+Mg5tx8oNLOp/iifxcBLEdbQbDrgKyLiibTP/2ngaTPbApQD3/C/BWw1s1XANqASuN0f6QPeQeLfA13wDvRG7WCviEhQRRT+zrlyYFEj85YBy8K0ZwITItmuiIhERmf4iogEkMJfRCSAFP4iIgEUqPDv0kk/2ygiAgELfxER8Sj8RUQCKFDhr6s7iIh4AhX+IiLiUfiLiASQwl9EJIAU/iIiAaTwFxEJIIW/iEgAKfxFRAJI4S8iEkCBCn/9kJeIiCdQ4S8iIh6Fv4hIACn8RUQCSOEvIhJACn8RkQAKVPjrks4iIp5Ahb+IiHgU/iIiAaTwFxEJoECFv5l6/UVEIGDh75wu8CAiAgELfxER8Sj8RUQCSOEvIhJAgQp/9fiLiHgSPvzvyxgb7RJERGJOwof/xGG9o12CiEjMSfjwnzqyX+20RvmLiHgSPvzPOONk5KvPX0TEE1H4m9lEM1tvZhvNLNPM0kPmLTWzLDPbYWZzQ9onm9lmf97jptNuRUTaXaR7/j8Hfuycmwj8yL+NmY0HFgLnAhnAE2aW5K/zJLAEGO3/ZURYg4iInKZIw98BPf3pXsA+f3oB8Jxzrsw5twfIAtLNbAjQ0zn3ofOutfAscG2ENTSbvmKIiHg6RLj+XcDrZvZLvA+Si/32ocD6kOVy/bYKf7p+e1hmtgTvWwIpKSkRlioiIjWaDH8zWwcM
DjPrAWA2cLdz7r/M7GvAfwBXEH4n252iPSzn3EpgJUBaWpqO14qItJImw985d0Vj88zsWeBO/+afgKf86VxgeMiiw/C6hHL96frtIiLSjiLt898HzPCnZwGf+9MvAwvNLNnMRuAd2P3IObcfKDSzqf4on8XASxHWICIipynSPv9/An5tZh2AUvz+eefcVjNbBWwDKoHbnXNV/jq3Ar8HugCr/b92oX4jERFPROHvnHsPmNzIvGXAsjDtmcCESLYrIiKRSfgzfENpqKeIiCdQ4S8iIp5Ahb/6/EVEPIEKfxER8QQq/NXnLyLiCVT4i4iIJ1Dhrz5/ERFPoMJfREQ8gQr/JP1ujIgIEJDw/+m13gnFF6b0jm4hIiIxIhDhf2bvzgAkd0hqYkkRkWCI9MJucWHGmIHcdvkobrlsZLRLERGJCYEI/6QzjPsyxkW7DBGRmBGIbh8REalL4S8iEkAKfxGRAFL4i4gEkMJfRCSAFP4iIgGk8BcRCSCFv4hIAJlz8XGhYzPLB75o4er9gUOtWE57itfa47VuUO3REK91Q+zXfpZzbkD9xrgJ/0iYWaZzLi3adbREvNYer3WDao+GeK0b4rd2dfuIiASQwl9EJICCEv4ro11ABOK19nitG1R7NMRr3RCntQeiz19EROoKyp6/iIiEUPiLiARQQoe/mWWY2Q4zyzKz+6NdD4CZPW1meWa2JaStr5mtNbPP/X/7hMxb6te/w8zmhrRPNrPN/rzHzdr21+nNbLiZvW1mfzOzrWZ2ZxzV3tnMPjKzz/zafxwvtfvbTDKzT83slTirO9vf5kYzy4yz2nub2fNmtt1/zU+Ll9qbzTmXkH9AErALGAl0Aj4DxsdAXdOBScCWkLafA/f70/cDP/Onx/t1JwMj/MeT5M/7CJgGGLAauKqN6x4CTPKnewA7/frioXYDuvvTHYH/AabGQ+3+Nu8B/hN4JV5eL/42s4H+9dripfZngFv86U5A73ipvdmPMdoFtOF/3jTg9ZDbS4Gl0a7LryWVuuG/AxjiTw8BdoSrGXjdf1xDgO0h7TcA/9bOj+ElYE681Q50BT4BLoqH2oFhwJvALE6Gf8zX7W8nm4bhH/O1Az2BPfgDYuKp9tP5S+Run6FATsjtXL8tFg1yzu0H8P8d6Lc39hiG+tP129uFmaUCF+LtQcdF7X7XyUYgD1jrnIuX2h8D7gOqQ9rioW4AB7xhZhvMbInfFg+1jwTygd/53W1PmVm3OKm92RI5/MP1rcXbuNbGHkPUHpuZdQf+C7jLOXf8VIuGaYta7c65KufcRLw96XQzm3CKxWOidjO7Gshzzm1o7iph2qL5ernEOTcJuAq43cymn2LZWKq9A17X7JPOuQuBYrxunsbEUu3NlsjhnwsMD7k9DNgXpVqactDMhgD4/+b57Y09hlx/un57mzKzjnjB/0fn3At+c1zUXsM5dwz4K5BB7Nd+CXCNmWUDzwGzzOwPcVA3AM65ff6/ecCfgXTio/ZcINf/dgjwPN6HQTzU3myJHP4fA6PNbISZdQIWAi9HuabGvAx8w5/+Bl5/ek37QjNLNrMRwGjgI/8rZ6GZTfVHDywOWadN+Nv5D+BvzrlH46z2AWbW25/uAlwBbI/12p1zS51zw5xzqXiv37ecc4tivW4AM+tmZj1qpoErgS3xULtz7gCQY2Zj/abZwLZ4qP20RPugQ1v+AfPwRqXsAh6Idj1+Tf8X2A9U4O0Z3Az0wzuo97n/b9+Q5R/w699ByEgBIA3vzbQL+A31Dk61Qd2X4n1l3QRs9P/mxUnt5wOf+rVvAX7kt8d87SHbvZyTB3xjvm68fvPP/L+tNe+/eKjd3+ZEINN/zbwI9ImX2pv7p8s7iIgEUCJ3+4iISCMU/iIiAaTwFxEJIIW/iEgAKfxFRAJI4S8iEkAKfxGRAPr/t4TDyxEb/ikAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def play_episode(env, agent, max_episode_steps=None, mode=None, render=False):\n",
    "    observation, reward, done = env.reset(), 0., False\n",
    "    agent.reset(mode=mode)\n",
    "    episode_reward, elapsed_steps = 0., 0\n",
    "    while True:\n",
    "        action = agent.step(observation, reward, done)\n",
    "        if render:\n",
    "            env.render()\n",
    "        if done:\n",
    "            break\n",
    "        observation, reward, done, _ = env.step(action)\n",
    "        episode_reward += reward\n",
    "        elapsed_steps += 1\n",
    "        if max_episode_steps and elapsed_steps >= max_episode_steps:\n",
    "            break\n",
    "    agent.close()\n",
    "    return episode_reward, elapsed_steps\n",
    "\n",
    "\n",
    "# Train until the moving average over the last 200 episodes exceeds the\n",
    "# reward threshold declared in the environment spec.\n",
    "logging.info('==== train ====')\n",
    "episode_rewards = []\n",
    "for episode in itertools.count():\n",
    "    # The unwrapped env is used so play_episode applies the step cap itself.\n",
    "    episode_reward, elapsed_steps = play_episode(env.unwrapped, agent,\n",
    "            max_episode_steps=env._max_episode_steps, mode='train')\n",
    "    episode_rewards.append(episode_reward)\n",
    "    logging.debug('train episode %d: reward = %.2f, steps = %d',\n",
    "            episode, episode_reward, elapsed_steps)\n",
    "    if np.mean(episode_rewards[-200:]) > env.spec.reward_threshold:\n",
    "        break\n",
    "\n",
    "# Learning curve: total reward of every training episode.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(episode_rewards)\n",
    "ax.set(title='Q-learning on Taxi-v3: training rewards',\n",
    "        xlabel='episode', ylabel='episode reward')\n",
    "\n",
    "\n",
    "# Evaluate the greedy policy (no exploration, no learning) on 100 episodes.\n",
    "logging.info('==== test ====')\n",
    "episode_rewards = []\n",
    "for episode in range(100):\n",
    "    episode_reward, elapsed_steps = play_episode(env, agent)\n",
    "    episode_rewards.append(episode_reward)\n",
    "    logging.debug('test episode %d: reward = %.2f, steps = %d',\n",
    "            episode, episode_reward, elapsed_steps)\n",
    "logging.info('average episode reward = %.2f ± %.2f',\n",
    "        np.mean(episode_rewards), np.std(episode_rewards))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "env.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
